{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 2: Secure Model Serving with TFE Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that you have a trained model with normal Keras, you are ready to serve some private predictions. We can do that using TFE Keras.\n",
    "\n",
    "To secure and serve this model, we will need three TFE servers. This is because TF Encrypted under the hood uses an encryption technique called [multi-party computation (MPC)](https://en.wikipedia.org/wiki/Secure_multi-party_computation). The idea is to split the model weights and input data into shares, then send a share of each value to the different servers. The key property is that if you look at the share on one server, it reveals nothing about the original value (input data or model weights).\n",
    "\n",
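    "As a rough illustration of the sharing idea described above, here is a toy additive secret-sharing sketch in plain Python. It is not the SecureNN protocol that TF Encrypted runs, and the names `share`, `reconstruct`, and `MODULUS` are hypothetical helpers chosen for this example only.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "# Illustrative ring size (an assumption); TF Encrypted's real encoding may differ.\n",
    "MODULUS = 2**64\n",
    "\n",
    "\n",
    "def share(value, n_parties=3):\n",
    "    # Draw all but one share uniformly at random.\n",
    "    shares = [secrets.randbelow(MODULUS) for _ in range(n_parties - 1)]\n",
    "    # Choose the last share so that all shares sum to the value modulo MODULUS.\n",
    "    shares.append((value - sum(shares)) % MODULUS)\n",
    "    return shares\n",
    "\n",
    "\n",
    "def reconstruct(shares):\n",
    "    # Only the sum of every share recovers the original value.\n",
    "    return sum(shares) % MODULUS\n",
    "\n",
    "\n",
    "secret = 42\n",
    "shares = share(secret)\n",
    "recovered = reconstruct(shares)\n",
    "```\n",
    "\n",
    "In this toy scheme each individual share is a uniformly random number, which is why a single server learns nothing from the share it holds.\n",
    "\n",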
    "If you want to learn more about MPC, you can read this excellent [blog](https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/).\n",
    "\n",
    "In this notebook, you will be able to serve private predictions after a series of simple steps:\n",
    "- Configure TFE Protocol to secure the model via secret sharing.\n",
    "- Launch three TFE servers.\n",
    "- Convert the TF Keras model into a TFE Keras model using `tfe.keras.models.clone_model`.\n",
    "- Serve the secured model using `tfe.serving.QueueServer`.\n",
    "\n",
    "Alright, let's do it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import tf_encrypted as tfe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, we define almost exactly the same model as before, except that we provide a `batch_input_shape`. This allows TF Encrypted to better optimize the secure computations via predefined tensor shapes. For this MNIST demo, we'll send input data with the shape `(1, 28, 28, 1)`.\n",
    "We also return the logits instead of applying softmax, because softmax is complex to perform using MPC and we don't need it to serve prediction requests."
   ]
  },
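  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Returning logits does not change the predicted class: softmax preserves the ordering of its inputs, so the largest logit is also the largest probability. The sketch below checks this with NumPy on plaintext values; the `softmax` helper and the sample logits are made up for illustration and are not part of the served model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the row maximum for numerical stability before exponentiating.\n",
    "    shifted = logits - np.max(logits, axis=-1, keepdims=True)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / np.sum(exp, axis=-1, keepdims=True)\n",
    "\n",
    "\n",
    "# Hypothetical logits with shape (1, 10), matching the model's output shape.\n",
    "logits = np.array([[2.0, -1.0, 0.5, 3.2, 0.0, -0.7, 1.1, 0.3, -2.0, 0.9]])\n",
    "\n",
    "predicted_from_logits = int(np.argmax(logits, axis=-1)[0])\n",
    "predicted_from_probs = int(np.argmax(softmax(logits), axis=-1)[0])\n",
    "```\n",
    "\n",
    "Both values select the same class index, so a client that needs probabilities can apply softmax locally after decrypting the logits."
   ]
  },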
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 10\n",
    "input_shape = (1, 28, 28, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "          tf.keras.layers.Conv2D(16, 8,\n",
    "                                 strides=2,\n",
    "                                 padding='same',\n",
    "                                 activation='relu',\n",
    "                                 batch_input_shape=input_shape),\n",
    "          tf.keras.layers.AveragePooling2D(2, 1),\n",
    "          tf.keras.layers.Conv2D(32, 4,\n",
    "                                 strides=2,\n",
    "                                 padding='valid',\n",
    "                                 activation='relu'),\n",
    "          tf.keras.layers.AveragePooling2D(2, 1),\n",
    "          tf.keras.layers.Flatten(),\n",
    "          tf.keras.layers.Dense(32, activation='relu'),\n",
    "          tf.keras.layers.Dense(num_classes, name='logit')\n",
    "  ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `load_weights` you can easily load the weights you have saved previously after training your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pre_trained_weights = 'short-dnn.h5'\n",
    "model.load_weights(pre_trained_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launching servers\n",
    "\n",
    "Before serving the computation, we need to launch the TFE servers in new processes. Run each of the commands printed below in its own terminal (three terminals in total). You may have to allow Python to accept incoming connections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "python -m tf_encrypted.player --config ./tfe.config server0\n",
      "python -m tf_encrypted.player --config ./tfe.config server1\n",
      "python -m tf_encrypted.player --config ./tfe.config server2\n"
     ]
    }
   ],
   "source": [
    "players = OrderedDict([\n",
    "    ('server0', 'localhost:4000'),\n",
    "    ('server1', 'localhost:4001'),\n",
    "    ('server2', 'localhost:4002'),\n",
    "])\n",
    "\n",
    "for player_name in players.keys():\n",
    "    print(\"python -m tf_encrypted.player --config ./tfe.config {}\".format(player_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Protocol"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first configure the servers on which we want to run the model, as well as the protocol we will be using. We will use the SecureNN protocol to secret share the model among the three TFE servers. Most importantly, this adds the capability of providing predictions on encrypted data.\n",
    "\n",
    "Note that the configuration is saved to file as we will be needing it in the client as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-11-01 10:03:14.958781: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job tfe -> {0 -> localhost:4000, 1 -> localhost:4001, 2 -> localhost:4002}\n",
      "2022-11-01 10:03:14.958807: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job localhost -> {0 -> localhost:42409}\n",
      "2022-11-01 10:03:14.967083: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job tfe -> {0 -> localhost:4000, 1 -> localhost:4001, 2 -> localhost:4002}\n",
      "2022-11-01 10:03:14.967100: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job localhost -> {0 -> localhost:42409}\n",
      "2022-11-01 10:03:14.967722: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:438] Started server with target: grpc://localhost:42409\n"
     ]
    }
   ],
   "source": [
    "config = tfe.RemoteConfig(players)\n",
    "config.save('./tfe.config')\n",
    "config.connect_servers()\n",
    "tfe.set_config(config)\n",
    "tfe.set_protocol(tfe.protocol.SecureNN())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert TF Keras into TFE Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Thanks to `tfe.keras.models.clone_model`, you can automatically convert the TF Keras model into a TFE Keras model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tfe.protocol.SecureNN():\n",
    "    tfe_model = tfe.keras.models.clone_model(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up a new `tfe.serving.QueueServer` "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tfe.serving.QueueServer` will launch a serving queue, so that the TFE servers can accept prediction requests on the secured model from external clients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up a new tfe.serving.QueueServer for the shared TFE model\n",
    "q_input_shape = (1, 28, 28, 1)\n",
    "q_output_shape = (1, 10)\n",
    "\n",
    "server = tfe.serving.QueueServer(\n",
    "    input_shape=q_input_shape, output_shape=q_output_shape, computation_fn=tfe_model\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Server"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perfect! With all of the above in place, we can finally push our TensorFlow graph to the servers and start serving the model. You can set `num_steps` to limit the number of prediction requests served by the model; if it is not specified, the model will be served until interrupted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Served encrypted prediction 1 to client.\n",
      "Served encrypted prediction 2 to client.\n",
      "Served encrypted prediction 3 to client.\n"
     ]
    }
   ],
   "source": [
    "request_ix = 1\n",
    "\n",
    "def step_fn():\n",
    "    global request_ix\n",
    "    print(\"Served encrypted prediction {i} to client.\".format(i=request_ix))\n",
    "    request_ix += 1\n",
    "\n",
    "server.run(num_steps=3, step_fn=step_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You are ready to move to the **c - Private Prediction Client** notebook to request some private predictions. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleanup!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the request limit above is reached, the model will no longer be available for serving requests, but it is still secret shared between the three servers above. You can kill the server processes by executing the cell below.\n",
    "\n",
    "**Congratulations** on finishing b - Secure Model Serving."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Process ID 177547 has been killed.\n",
      "Process ID 178030 has been killed.\n",
      "Process ID 178508 has been killed.\n"
     ]
    }
   ],
   "source": [
    "process_ids = !ps aux | grep '[p]ython -m tf_encrypted.player --config' | awk '{print $2}'\n",
    "for process_id in process_ids:\n",
    "    !kill {process_id}\n",
    "    print(\"Process ID {id} has been killed.\".format(id=process_id))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('tfe2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "2b3cf1a821f25a2bafd97d4b4f3a77d0e1f42d8d78c0e512d3a72b703c70416e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
